{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YSmlaedJpOjs"
   },
   "source": [
    "# Distributional Reinforcement Learning with Quantile Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://ars-ashuha.github.io/images/QR-Net1.png\", width=1200>\n",
    "<img src=\"https://ars-ashuha.github.io/images/QR-Net2.png\", width=1200>\n",
    "\n",
    "- Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/pdf/1710.10044.pdf \n",
    "- The solution is you got stuck you could cheat a little bit https://github.com/ars-ashuha/quantile-regression-dqn-pytorch "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qMQ0ozclpOju"
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import torch\n",
    "import pickle\n",
    "import random\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from logger import Logger\n",
    "from rl_utils import ReplayMemory, huber"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JOzr85_ipOjy"
   },
   "outputs": [],
   "source": [
    "class Network(nn.Module):\n",
    "    def __init__(self, len_state, num_quant, num_actions):\n",
    "        nn.Module.__init__(self)\n",
    "        \n",
    "        self.num_quant = num_quant\n",
    "        self.num_actions = num_actions\n",
    "        \n",
    "        ###########################################################\n",
    "        ########         You Code should be here         ##########\n",
    "        # Define your model here, it is ok to use just \n",
    "        # two layers and tanh nonlinearity, do not forget that \n",
    "        # shape of the output should be \n",
    "        # batch_size x self.num_actions x self.num_quant\n",
    "        self.layer1 = ....\n",
    "        ###########################################################\n",
    "        \n",
    "    def forward(self, x):\n",
    "        ###########################################################\n",
    "        ########         You Code should be here         ##########\n",
    "        # Compute the output of the network\n",
    "        x = ....\n",
    "        return x\n",
    "        # Tensor of shape batch_size x self.num_actions x self.num_quant\n",
    "        ###########################################################\n",
    "    \n",
    "    def select_action(self, state, eps):\n",
    "        if not isinstance(state, torch.Tensor): \n",
    "            state = torch.Tensor([state])    \n",
    "            \n",
    "        action = torch.randint(0, 2, (1,))\n",
    "        if random.random() > eps:\n",
    "            ###########################################################\n",
    "            ########         You Code should be here         ##########\n",
    "            action = # Select Greedy action wrt Q(s, a) = E(Z(s, a))\n",
    "            ###########################################################\n",
    "        return int(action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iYIjEGQLpOj0"
   },
   "outputs": [],
   "source": [
    "# Here we've defined a schedule for exploration i.e. random action with prob eps\n",
    "eps_start, eps_end, eps_dec = 0.9, 0.1, 500 \n",
    "eps = lambda steps: eps_end + (eps_start - eps_end) * np.exp(-1. * steps / eps_dec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "urnaFCsrpOj3",
    "outputId": "90de4df1-f955-41a1-856b-f177ddba4dc2"
   },
   "outputs": [],
   "source": [
    "# We start from CartPole-v0         \n",
    "# and then will solve MountainCar-v0\n",
    "env_name = 'CartPole-v0' \n",
    "env = gym.make(env_name)\n",
    "\n",
    "memory = ReplayMemory(10000)\n",
    "logger = Logger('q-net', fmt={'loss': '.5f'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DxAK_8ZKpOj8"
   },
   "outputs": [],
   "source": [
    "Z = # Define Z an approximation network \n",
    "Ztgt = # Define Z a target network \n",
    "optimizer = optim.Adam(Z.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SvYi1J7ZpOj_"
   },
   "outputs": [],
   "source": [
    "tau = torch.Tensor((2 * np.arange(Z.num_quant) + 1) / (2.0 * Z.num_quant)).view(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training cicle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 221
    },
    "colab_type": "code",
    "id": "97fqMzqhpOkC",
    "outputId": "7c59248c-8187-48a1-c7c9-97abc1d13f01"
   },
   "outputs": [],
   "source": [
    "gamma, batch_size = 0.99, 32 \n",
    "steps_done, running_reward = 0, None\n",
    "\n",
    "for episode in range(501): \n",
    "    sum_reward = 0\n",
    "    state = env.reset()\n",
    "    while True:\n",
    "        steps_done += 1\n",
    "        \n",
    "        action = Z.select_action(torch.Tensor([state]), eps(steps_done))\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "        memory.push(state, action, next_state, reward, float(done))\n",
    "        sum_reward += reward\n",
    "        \n",
    "        if len(memory) < batch_size: break    \n",
    "            \n",
    "        ###########################################################\n",
    "        ########         You Code should be here         ##########\n",
    "        # Sample transitions from Replay Memory\n",
    "        states, actions, rewards, next_states, dones = ...\n",
    "        ###########################################################\n",
    "        \n",
    "        ###########################################################\n",
    "        ########         You Code should be here         ##########\n",
    "        # Calculate quantiles theta for current state and actions\n",
    "        theta = ...\n",
    "        # Calculate quantiles for the next stage with target network \n",
    "        # and then take value for a max action\n",
    "        Znext_max = ...\n",
    "        Ttheta = rewards + gamma * (1 - dones) * Znext_max\n",
    "        # Calculate loss, use this trick to compute pairwise differences\n",
    "        # Trick Tensor of shape (3,2,1) minus Tensor of shape (1,2,3) is Tensor of shape (3, 2, 3)\n",
    "        # With all pairwise differences :)\n",
    "        # Use Huber elementwise function to compute Huber loss\n",
    "        diff = Ttheta.t().unsqueeze(-1) - theta \n",
    "        loss = torch.mean(huber(diff) * (tau - (diff.detach() < 0).float()).abs())\n",
    "        ###########################################################\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        state = next_state\n",
    "        \n",
    "        if steps_done % 100 == 0:\n",
    "            Ztgt.load_state_dict(Z.state_dict())\n",
    "            \n",
    "        if done and episode % 50 == 0:\n",
    "            logger.add(episode, steps=steps_done, running_reward=running_reward, loss=loss.data.numpy())\n",
    "            logger.iter_info()\n",
    "            \n",
    "        if done: \n",
    "            running_reward = sum_reward if not running_reward else 0.2 * sum_reward + running_reward*0.8\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "H5X4DVAqpOkG"
   },
   "source": [
    "# Vizualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PWq4Wnxr74Yc"
   },
   "source": [
    "## Train the model for  MountainCar-v0 env here!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CEyateaApOkI"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.set_style('whitegrid')\n",
    "\n",
    "from matplotlib import rcParams\n",
    "rcParams['figure.figsize'] = 7, 2\n",
    "rcParams['figure.dpi'] = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WtXI9QlbpOkK"
   },
   "outputs": [],
   "source": [
    "actions={\n",
    "    'CartPole-v0': ['Left', 'Right'],\n",
    "    'MountainCar-v0': ['Left', 'Non', 'Right'],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ijdMnuNEpOkP"
   },
   "outputs": [],
   "source": [
    "def get_plot(q):\n",
    "    eps, p = 1e-8, 0\n",
    "    x, y = [q[0]-np.abs(q[0]*0.2)], [0]\n",
    "    for i in range(0, len(q)):\n",
    "        x += [q[i]-eps, q[i]]\n",
    "        y += [p, p+1/len(q)]\n",
    "        p += 1/len(q)\n",
    "    x+=[q[i]+eps, q[i]+np.abs(q[i]*0.2)]\n",
    "    y+=[1.0, 1.0]\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QUSc36rspOkS"
   },
   "outputs": [],
   "source": [
    "state, done, steps = env.reset(), False, 0\n",
    "while True:\n",
    "    plt.clf()\n",
    "    steps += 1\n",
    "    action = Z.select_action(torch.Tensor([state]), eps(steps_done))\n",
    "    state, reward, done, _ = env.step(action)\n",
    "    \n",
    "    if steps % 3 == 0:  \n",
    "        plt.subplot(1, 2, 1)\n",
    "        plt.title('step = %s' % steps)\n",
    "        plt.imshow(env.render(mode='rgb_array'))\n",
    "        plt.axis('off')\n",
    "\n",
    "        plt.subplot(1, 2, 2)\n",
    "        Zval = Z(torch.Tensor([state])).detach().numpy()\n",
    "        for i in range(env.action_space.n):\n",
    "            x, y = get_plot(Zval[0][i])\n",
    "            plt.plot(x, y, label='%s Q=%.1f' % (actions[env_name][i], Zval[0][i].mean()))\n",
    "            plt.legend(bbox_to_anchor=(1.1, 1.1), ncol=3, prop={'size': 6})\n",
    "\n",
    "        if done: break\n",
    "        display.clear_output(wait=True)\n",
    "        display.display(plt.gcf())\n",
    "plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0gWdTEfApOkW"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "qr-dqn-solution-cool.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
